Conversation

@thomasloux
Collaborator

Summary

First draft of constraints, inspired by ASE.
This is a rather big feature, so I prefer to post a first MVP so that we can discuss the implementation and make sure constraints are compatible with the integrators and optimizers.

Codes implementing constraints

  • JAX-MD: no
  • ASE: yes
  • LAMMPS: yes, but difficult to read (C++) and copy.
  • OpenMM: yes, but only bond lengths or angles as far as I understand.
    I did not check the rest in much detail.

Implementation:

  • Constraints: FixAtoms, FixCom

Changes in the code:

  • Add an optional constraints variable in SimState
  • Make every subclass a dataclass(kw_only=True) to prevent errors (a field without a default following a field with a default)
  • Adapt MD only, via setters (set_momenta, set_positions)
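
To make the shape of these hooks concrete, here is a minimal sketch of a constraint base class and a FixAtoms that freezes selected atoms. It is not the PR's code: plain Python lists stand in for torch tensors, and the class layout and method signatures are assumptions chosen for illustration.

```python
from abc import ABC, abstractmethod


class Constraint(ABC):
    """Hypothetical base class: each hook edits the proposed values in place."""

    @abstractmethod
    def adjust_positions(self, old_positions, new_positions) -> None: ...

    @abstractmethod
    def adjust_momenta(self, momenta) -> None: ...


class FixAtoms(Constraint):
    """Freeze the atoms listed in `indices` (indices into the batched state)."""

    def __init__(self, indices):
        self.indices = list(indices)

    def adjust_positions(self, old_positions, new_positions) -> None:
        # Put the frozen atoms back where they were before the update.
        for i in self.indices:
            new_positions[i] = list(old_positions[i])

    def adjust_momenta(self, momenta) -> None:
        # A frozen atom carries no momentum.
        for i in self.indices:
            momenta[i] = [0.0, 0.0, 0.0]


old = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
new = [[0.1, 0.0, 0.0], [1.2, 0.0, 0.0]]
momenta = [[1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]

fix = FixAtoms([0])
fix.adjust_positions(old, new)
fix.adjust_momenta(momenta)
```

After these calls atom 0 is back at its old position with zero momentum, while atom 1 keeps its proposed update.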

What's working perfectly

  • Constraint tests implemented
  • All tests pass
  • FixAtoms really fixes the atoms concerned

Future work

  • Adapt optimizer
  • Make sure that FixCom works well. The COM drift is small, but I would expect an even smaller value. ASE seems to give a smaller value in first tests. I have not yet reproduced their tests, as they are based on an optimization task.
  • Implement other constraints
  • Make sure everything is compatible with the integrators and optimizers. Essentially, we should apply the RATTLE algorithm.
  • Adapt the temperature calculation to take into account the reduced number of degrees of freedom.
  • For FixCom, one wants the center of mass to be fixed in unwrapped coordinates, so we need to check after unwrapping. As a result, I now call set_momenta and then apply the wrapping to the systems. ASE does not actually wrap during the simulation; the wrapping is applied in each calculator forward pass. I think that is also the case for most TorchSim model implementations. So we may want to stop wrapping during the simulations. Of course, it is always possible to wrap afterwards.

Implementation discussions

  • I first wanted a hidden variable _positions with an implicit setter and getter. But this is very painful to reconcile with _atom_attributes and with the copies of the state that are sometimes made by directly accessing the variables using vars(state). At least that is the case if the init function accepts a positions argument rather than a _positions argument. Actually, I think an explicit setter makes it clearer that the constraints are being imposed.
  • I decided to add constraints as a global attribute, but prevent list of constraint and looping over the constraints.
  • The disadvantage: it's rather difficult to write the constraints and define them over the batched system. For FixAtoms you need to provide the indices of the atoms in the batched system. We probably want to define constraints per system and then batch them. For FixCom, we need to compute the COM efficiently, which is not very easy to read. It will be even worse if one wants to fix the COM of a subgroup in a batched system.
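
As an illustration of the batched COM bookkeeping mentioned in the last point, the following sketch removes the center-of-mass momentum of every system in a batch using a per-atom system_idx. Plain lists stand in for tensors, and the function name is hypothetical; a tensor version would presumably replace the accumulation loop with a scatter-style reduction over system_idx.

```python
def remove_com_momentum(momenta, masses, system_idx, n_systems):
    """Subtract each system's center-of-mass velocity from its atoms' momenta."""
    total_p = [[0.0, 0.0, 0.0] for _ in range(n_systems)]
    total_m = [0.0] * n_systems
    # Accumulate total momentum and total mass per system.
    for p, m, s in zip(momenta, masses, system_idx):
        total_m[s] += m
        for k in range(3):
            total_p[s][k] += p[k]
    # p_i <- p_i - m_i * v_com(system of atom i)
    return [
        [p[k] - m * total_p[s][k] / total_m[s] for k in range(3)]
        for p, m, s in zip(momenta, masses, system_idx)
    ]


masses = [1.0, 3.0, 2.0, 2.0]
system_idx = [0, 0, 1, 1]
momenta = [[1.0, 0.0, 0.0], [3.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

fixed = remove_com_momentum(momenta, masses, system_idx, n_systems=2)
```

Each system's momenta sum to zero afterwards, independently of the other systems in the batch.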

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

We highly recommend installing the prek hooks that run in CI locally to speed up the development process. Simply run pip install prek && prek install to install the hooks, which will check your code before each commit.

Collaborator

@curtischong curtischong left a comment

This looks good!

I like how the abstract methods like adjust_positions are flexible enough for constraining specific indices/atoms and for constraining specific axes of movement (which'll probably be added later)

Collaborator

@orionarcher orionarcher left a comment

Really nice job with this @thomasloux! I think this is most of the way there.

I have a few scattered comments but mostly I have thoughts on the API.

So we may want to stop wrapping during the simulations.

Agreed, this is a longstanding issue noted in #17.

I decided to add constraints as a global attribute, but prevent list of constraint and looping over the constraints.

I think a global attribute is the right approach. It doesn't make sense to concatenate or stack the constraints, so atom and system attributes don't fit. That said, I think we'll need to add some special handling for concatenate and split operations to make sure this doesn't break state handling.

The disadvantage: it's rather difficult to write the constraints and define them over the batched system. For FixAtoms you need to provide the indices of the atoms in the batched system.

I think the usage pattern should be to define constraints for all the singular states and then concatenate them together, letting TorchSim handle making sure the indices are properly updated. This would also mean only allowing one constraint of each type on a given SimState. We could also imagine letting folks set constraints in ASE and then automatically porting them over.
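
A rough sketch of the index bookkeeping this usage pattern implies: each system declares its fixed atoms with local indices, and concatenation shifts them by the number of atoms that precede the system in the batch. The function name and argument layout are hypothetical.

```python
def concatenate_fixed_indices(per_system_indices, atoms_per_system):
    """Shift each system's local atom indices by the number of atoms before it."""
    merged, offset = [], 0
    for local, n_atoms in zip(per_system_indices, atoms_per_system):
        merged.extend(i + offset for i in local)
        offset += n_atoms
    return merged


# Three systems with 3, 4 and 2 atoms; the third has no fixed atoms.
batched = concatenate_fixed_indices([[0, 2], [1], []], [3, 4, 2])
```

Here local atom 1 of the second system becomes atom 4 of the batched state.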

For FixCom, we need to compute the COM efficiently, which is not very easy to read. It will be even worse if one wants to fix the COM of a subgroup in a batched system.

Yeah this is a bit of a headache. I'll think on this too.

Broadly, I think the design decisions here are solid and just need a bit more buildout to get to a good API. In particular, I think we'll need a few more additions to this PR:

  1. Constraints update when state is modified. _split_state, _pop_states, _slice_state, and concatenate_states will need to be modified to correctly adjust the indices of the constraints when the state is mutated.
  2. Autovalidation of constraints. I think there should be a validate_constraints method that makes sure there are no overlapping constraints and that all constraints operate within a single system idx (not across multiple batches). This can be called both in the post_init method of SimState and in a set_constraint method if we add one.
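
The two checks in item 2 could look roughly like the sketch below. It only illustrates the proposal: constraints are reduced to plain lists of atom indices, and the name validate_constraints_sketch and its return convention (a list of messages) are assumptions, not the PR's API.

```python
from collections import Counter


def validate_constraints_sketch(constraint_indices, system_idx):
    """Return human-readable problems; an empty list means nothing was flagged."""
    problems = []
    # Check 1: the same atom index listed by more than one constraint.
    counts = Counter(i for indices in constraint_indices for i in set(indices))
    overlapping = sorted(i for i, n in counts.items() if n > 1)
    if overlapping:
        problems.append(f"atoms {overlapping} appear in more than one constraint")
    # Check 2: a single constraint touching atoms from several systems.
    for n, indices in enumerate(constraint_indices):
        systems = {system_idx[i] for i in indices}
        if len(systems) > 1:
            problems.append(f"constraint {n} spans systems {sorted(systems)}")
    return problems


system_idx = [0, 0, 1, 1]
flagged = validate_constraints_sketch([[0, 1], [1, 2]], system_idx)
clean = validate_constraints_sketch([[0], [2, 3]], system_idx)
```

Returning messages rather than raising keeps the sketch neutral on whether a flagged case should be an error or only a warning.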

I'm currently traveling but am happy to pair on this PR when I am back (Monday) to help with both of the above. This has been on my todo list for a while.

from torch_sim.state import SimState


class FixConstraint(ABC):
Collaborator

If this is going to be the base class for all constraints I'd favor naming it Constraint

Comment on lines 1162 to 1164
def unwrap_positions(
    pos: torch.Tensor, box: torch.Tensor, system_idx: torch.Tensor
) -> torch.Tensor:
Collaborator

This is just used in the tests so I'd suggest moving it into the testing file and making it a private method. Unwrapping coordinates is quite tricky to do right, and I don't want to imply we've done it perfectly by adding it to the public API.

Collaborator Author

@thomasloux thomasloux Oct 21, 2025

For now I can do that, but I think this implementation should be correct, at least assuming that the displacement at each step is small enough.
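
The small-displacement assumption can be illustrated with a one-dimensional sketch for a periodic box of fixed length. This is a simplified stand-in, not the unwrap_positions function under review: it handles one coordinate of one atom, and it is only valid if no step exceeds half the box length.

```python
def unwrap_trajectory(wrapped, box_length):
    """Unwrap a 1D wrapped trajectory, assuming every step moves less than half a box."""
    unwrapped = [wrapped[0]]
    for prev, curr in zip(wrapped, wrapped[1:]):
        step = curr - prev
        # Minimum-image displacement: undo any jump across the boundary.
        step -= box_length * round(step / box_length)
        unwrapped.append(unwrapped[-1] + step)
    return unwrapped


# The atom crosses the upper boundary between the second and third frame.
traj = unwrap_trajectory([9.0, 9.75, 0.5, 1.25], box_length=10.0)
```

If a single step were larger than half the box, the minimum-image correction would pick the wrong periodic image, which is why the assumption matters.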

Collaborator

ik this is a helper function, but if we're using it to verify the correctness of other code, it'd be nice to have a test for this

@orionarcher orionarcher added the breaking (Breaking changes) and feature (Entirely new features, not improvements to existing ones) labels Oct 21, 2025
thomasloux and others added 3 commits October 21, 2025 14:46
@thomasloux
Collaborator Author

thomasloux commented Oct 21, 2025

@orionarcher thanks for the feedback. I'm starting to adapt _split_state, _pop_states, _slice_state, and concatenate_states. _filter_attrs_by_mask seems to be slightly too general: it accepts both an atom_mask and a system_mask, although you probably expect the result to depend only on system_mask. I would suggest determining the atom_mask inside the function, even if it is slightly less efficient.

EDIT: or find a way to be sure that atom_mask always corresponds to the system_mask
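
Deriving the atom mask from the system mask is a single lookup per atom, as in this sketch (lists stand in for tensors and the helper name is hypothetical; with tensors this would presumably be an indexing operation of the form system_mask[system_idx]).

```python
def atom_mask_from_system_mask(system_mask, system_idx):
    """An atom is kept exactly when the system it belongs to is kept."""
    return [system_mask[s] for s in system_idx]


system_mask = [True, False, True]
system_idx = [0, 0, 1, 2, 2]  # system of each atom
atom_mask = atom_mask_from_system_mask(system_mask, system_idx)
```

Computing the mask this way guarantees the two masks cannot disagree, at the cost of one extra gather.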

@thomasloux
Collaborator Author

Hey, I've changed the constraints API quite a lot. I've essentially reproduced the equivalent of atom_attribute and system_attribute at the level of constraints. The code is still a draft; there are probably better ways to reorganize the operations, especially the state manipulations (pop, split, concatenate), which are probably difficult to read and too long.

  • State manipulations are now supported
  • I removed wrapping
  • Added a fixed reference in FixCom (set at the first call)
  • Added tests for constraints with optimizers
  • Modified calc_kt to support the new degrees of freedom

Regarding SystemConstraint, I need to implement a default slice(None) so that the constraint still works when the end user sets the constraint directly (state.constraint = [FixCom()]). We could decide not to support this.

Also, I realize that general_attribute is currently only used for pbc, which should change soon. We should put constraint in its own category, as it is not even propagated in the same way as the other attributes. This would also allow a private _constraint variable, which would support the previous point about not setting constraint directly.

More than open to suggestions on how best to add constraints.

By the way, in the meantime, we may want to have a dedicated branch so that people can try it (a lot of demand for optimization, #114). It's probably better to do this before we have extensive tests and proof that our optimizers and integrators are compatible with the constraints.

Regarding calc_kt, I suggest making it an MDState method (like in ASE), so that the user does not forget to account for the degrees of freedom.
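
To show why the degrees of freedom matter, here is a hedged sketch of an equipartition temperature with constrained degrees of freedom removed. The function calc_kt_sketch and its arguments are hypothetical and unrelated to the library's actual calc_kt signature; it assumes each fixed atom removes three degrees of freedom and a fixed center of mass removes three more.

```python
def calc_kt_sketch(momenta, masses, n_fixed_atoms=0, fix_com=False):
    """kT = 2 * KE / n_dof, with constrained degrees of freedom removed."""
    kinetic = sum(
        sum(c * c for c in p) / (2.0 * m) for p, m in zip(momenta, masses)
    )
    n_dof = 3 * (len(masses) - n_fixed_atoms) - (3 if fix_com else 0)
    if n_dof <= 0:
        raise ValueError("no free degrees of freedom left")
    return 2.0 * kinetic / n_dof


momenta = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
masses = [1.0, 1.0, 1.0]

kt_unconstrained = calc_kt_sketch(momenta, masses)                # 9 dof
kt_one_fixed = calc_kt_sketch(momenta, masses, n_fixed_atoms=1)   # 6 dof
```

With one atom frozen, the same kinetic energy is shared among fewer degrees of freedom, so ignoring the constraint would underestimate the temperature.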

I have not added a check for incompatible constraints for now; I'm not sure it is relevant. For instance, in the future you may want to fix all the bond lengths in a chain (think of a polymer). Then the index of an atom will appear multiple times.

@janosh janosh marked this pull request as ready for review November 10, 2025 03:24
@janosh
Collaborator

janosh commented Nov 10, 2025

@thomasloux great work on this! i'd like to start using symmetry constrained relaxation and would like to help push this over the finish line. if you don't mind i'll push a few minor bug fixes and some more tests?

…nd checks atom indices exist in state if provided + that all constrained atoms belong to same system
- check that constraints work correctly with non-periodic boundaries and in batched states
@thomasloux
Collaborator Author

Hey, no problem for me. Actually, I need to modify the current implementation of FixCom. I also need to add a warning if someone uses it with NPT integrators, because I'm almost sure the current implementation is not ready (at least for the simple reason that the constraint should be rescaled with the cell).

Collaborator Author

@thomasloux thomasloux left a comment

I think this check validate_constraints should be changed

This function checks for potential issues like multiple constraints
acting on the same atoms, which could lead to unexpected behavior.
This function checks for:
1. Overlapping atom indices across multiple constraints
Collaborator Author

For 2. AtomIndexedConstraints spanning multiple systems:
The current FixAtoms is implemented so that there is supposed to be only one FixAtoms constraint for a batched system, so it is expected to act on multiple systems.

For 1., it's not so clear that having the same atoms affected by multiple constraints is a problem. A concrete example: in a water molecule, you often want to constrain the bonds involving hydrogen (the two O-H bonds). The oxygen atom is then involved in 2 constraints.

Collaborator

@orionarcher orionarcher left a comment

Looks good, nice work @thomasloux! This will be an awesome feature to have and this PR creates a strong abstraction for the implementation of future constraints. Well done.

The main request I have before merging is that we settle on a nomenclature for the constraint indices. I don't like having indices for AtomConstraint and system_idx for SystemConstraint. We could have indices for both or do atom_idx and system_idx.

Comment on lines +727 to +738
# take into account constraints that are AtomConstraint
filtered_attrs["_constraints"] = [
    constraint.select_constraint(atom_mask, system_mask)
    for constraint in copy.deepcopy(state.constraints)
]
# Remove any None constraints resulting from selection
filtered_attrs["_constraints"] = [
    constraint
    for constraint in filtered_attrs["_constraints"]
    if constraint is not None
]

Collaborator

Can we do this in a single comprehension with an if statement and walrus operator, like below?
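
The reviewer's suggested snippet did not survive in this copy of the thread. One way such a single comprehension could look is sketched below; FixAtomsStub is a stand-in written only so the example runs, and its select_constraint is an assumption about the real method's behaviour (returning None when nothing is left to constrain).

```python
import copy


class FixAtomsStub:
    """Stand-in constraint; returns None when no constrained atom survives the mask."""

    def __init__(self, atom_idx):
        self.atom_idx = atom_idx

    def select_constraint(self, atom_mask, system_mask):
        kept = [i for i in self.atom_idx if atom_mask[i]]
        return FixAtomsStub(kept) if kept else None


constraints = [FixAtomsStub([0, 1]), FixAtomsStub([3])]
atom_mask = [True, True, True, False]
system_mask = [True, False]

# Select and drop None results in one pass using an assignment expression.
selected_constraints = [
    selected
    for constraint in copy.deepcopy(constraints)
    if (selected := constraint.select_constraint(atom_mask, system_mask)) is not None
]
```

The assignment expression avoids calling select_constraint twice and removes the need for the second filtering pass.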

@thomasloux
Collaborator Author

I already changed that to atom_idx and system_idx!

"""
self.cell = value.mT

def set_positions(self, new_positions: torch.Tensor) -> None:
Collaborator

since we want people to use this new function, maybe we should add a warning if they don't use set_positions?

@property
def positions(self) -> torch.Tensor:
    """Atomic positions."""
    return self._positions

@positions.setter
def positions(self, value: torch.Tensor) -> None:
    logger.warning("Directly setting state.positions is deprecated. Use set_positions instead.")
    self._positions = value

Collaborator

@curtischong curtischong Nov 25, 2025

Or we can make the set_positions function the @property setter of positions instead, so it will automatically handle constraints and people never need to know about the set_positions function. We may need to prefix positions with an underscore, since the variable might shadow the @property definition, but I'm not sure.

Collaborator

I think this is a small but important thing to do because we don't want people to accidentally use the wrong functions to set positions/forces.

Collaborator

Ahh yeah this is important. I agree that having set_positions and positions is a bit confusing.

I would suggest we instead call this constrain_positions which has no arguments and applies the constraints to the existing positions. I don't think setting positions should necessarily incur the application of constraints, and having the naming mixed up is confusing. A separate operation strikes me as the cleanest solution. We could call it apply_positions constraints too.

Whatever we do, let's make sure to propagate it to forces and momenta as well.

So the flow would look like:

state.positions = new_positions
state.constrain_positions()

Collaborator Author

Actually, this set_positions mostly makes sense for under-the-hood algorithms. I don't think you can add @property on something that is already a variable of the class, otherwise you just replace it. So your approach would work with a hidden _positions, but this starts to change the current logic a lot. By the way, you don't always want to apply constraints.
Example: in MD, it's redundant to apply them to both forces and momenta. It's probably not wrong, though it may not be this simple for more complicated constraints like FixBonds, but you do the work twice.

Collaborator

Ok upon reflection my proposed solution would not work. I suggest instead:

  1. Removing the set_positions, set_momenta, and set_forces functions from the states
  2. Instead creating constrain_positions, constrain_momenta, and constrain_forces functions in constraints.py, these would have the following signature and use:
def constrain_positions(positions: torch.Tensor, constraints: list[Constraint], reference_state: ts.SimState) -> torch.Tensor:
    # adjust_positions is assumed to modify positions in place
    for constraint in (constraints or []):
        constraint.adjust_positions(reference_state, positions)
    return positions

# then used by
constrain_positions(new_positions, state.constraints, state)
state.positions = new_positions

# of course with a slight adjustment we could also do
constrain_positions(new_positions, state)
state.positions = new_positions

I think removing the constrain_X logic from the SimStates is preferable because it makes it more reusable across the package and isolates the constraint logic in constraints.py.

Collaborator Author

I would potentially prefer a more verbose
state.set_positions(new_positions, apply_constraints=True). The only reason is to have a single line when you want to set new positions.
But I don't have a strong opinion on any of the different implementations we discussed.

Collaborator

@orionarcher orionarcher Nov 27, 2025

Yeah I agree having a one liner is nice. What about:

from constraints import set_constrained_positions

set_constrained_positions(state, new_positions)

I prefer not having two ways to do the same thing and I like having all the constraints logic in constraints. Could call it something else, I'd be ok with simply set_positions too.

Collaborator Author

I've changed my mind. A one-liner like in ASE is a good option when you have a private variable _positions and force the end user to use set_positions. In our case, that's not possible. So I think it's better, as you said, to keep it in 2 lines:

state.positions = positions
state.constrain_positions()

The code is more verbose but also more self-explanatory.
I think I prefer this form because one understands that the constraints come from the state and modify its internal positions, rather than a free constrain_positions function, which I find a bit awkward with either signature: constrain_positions(new_positions, state.constraints, state) or constrain_positions(new_positions, state). In the first case, it's a bit weird to have to provide state (although it's needed for some constraints, e.g. FixAtoms). In the second case, the signature is essentially the same as state.constrain_positions().
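
A minimal sketch of this two-line flow, with a toy state class and plain lists in place of tensors. Everything here is illustrative: because the old positions are already overwritten when constrain_positions runs, this sketch's PinAtoms keeps its own reference copy of the positions to restore, which is an assumption of the example and not necessarily how the PR's FixAtoms works.

```python
from dataclasses import dataclass, field


class PinAtoms:
    """Toy constraint: restores selected atoms to stored reference positions."""

    def __init__(self, atom_idx, reference):
        self.atom_idx = list(atom_idx)
        self.reference = [list(p) for p in reference]

    def adjust_positions(self, state, positions) -> None:
        for i in self.atom_idx:
            positions[i] = list(self.reference[i])


@dataclass
class ToyState:
    positions: list
    constraints: list = field(default_factory=list)

    def constrain_positions(self) -> None:
        # Apply every constraint, in order, to the current positions in place.
        for constraint in self.constraints:
            constraint.adjust_positions(self, self.positions)


initial = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
state = ToyState(
    positions=[list(p) for p in initial],
    constraints=[PinAtoms([0], initial)],
)

# The two-line flow discussed above: assign, then constrain.
state.positions = [[0.3, 0.0, 0.0], [1.4, 0.0, 0.0]]
state.constrain_positions()
```

Atom 0 ends up back at its reference position while atom 1 keeps the newly assigned value.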

Collaborator

Sounds good to me! I could also imagine there could be special states that need their own constraint implementations, so I am happy to leave it on the state. It's a pretty short function, anyhow.

raise ValueError(
"The indices array contains duplicates. "
"Perhaps you want to specify a mask instead, but "
"forgot the mask= keyword."
Collaborator

Not sure what this mask= keyword is, but maybe it is left over from older code. I think we can end this message at "Perhaps you want to specify a mask instead".

Collaborator Author

I actually added support for mask. These init functions were copied from ASE a bit too quickly.

Collaborator

@curtischong curtischong left a comment

I think this is great! There are a few things I think could be improved (e.g. allowing 0 dof in state.py), but in general it seems like we have constraints now.

@curtischong curtischong self-requested a review November 25, 2025 17:49
@orionarcher
Copy link
Collaborator

orionarcher commented Dec 1, 2025

Fantastic job getting this to the finish line, @thomasloux. The only unresolved issue is the question of position/momentum/force assignment with constraints. Happy to be an extra pair of hands on that if desired.

